{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MhoQ0WE77laV"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "_ckMIh7O7s6D"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jYysdyb-CaWM"
      },
      "source": [
        "# Индивидуальное обучение с tf.distribute.Strategy"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S5Uhzt6vVIB2"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/distribute/custom_training\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />Смотрите на TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ru/tutorials/distribute/custom_training.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Запустите в Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ru/tutorials/distribute/custom_training.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />Изучайте код на GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ru/tutorials/distribute/custom_training.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Скачайте ноутбук</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8a2322303a3f"
      },
      "source": [
        "Note: Вся информация в этом разделе переведена с помощью русскоговорящего Tensorflow сообщества на общественных началах. Поскольку этот перевод не является официальным, мы не гарантируем что он на 100% аккуратен и соответствует [официальной документации на английском языке](https://www.tensorflow.org/?hl=en). Если у вас есть предложение как исправить этот перевод, мы будем очень рады увидеть pull request в [tensorflow/docs](https://github.com/tensorflow/docs) репозиторий GitHub. Если вы хотите помочь сделать документацию по Tensorflow лучше (сделать сам перевод или проверить перевод подготовленный кем-то другим), напишите нам на [docs-ru@tensorflow.org list](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-ru)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FbVhjPpzn6BM"
      },
      "source": [
        "В этом руководстве будет показано, как использовать [`tf.distribute.Strategy`](https://www.tensorflow.org/guide/distributed_training) с пользовательскими циклами обучения. Мы обучим простую модель CNN на наборе данных Fashion MNIST. Датасет Fashion MNIST содержит 60000 тренироваочных изображений размером 28x28 пикселей и 10000 тестовых изображений размером 28x28 пикселей.\n",
        "\n",
        "Мы будем использовать кастомные циклы для обучения нашей модели, потому что они дают нам большую гибкость и больший контроль над обучением. Более того, при использовании кастомных циклов, - легче отлаживать модель и цикл обучения."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dzLKpmZICaWN"
      },
      "outputs": [],
      "source": [
        "# Import TensorFlow\n",
        "import tensorflow as tf\n",
        "\n",
        "# Helper libraries\n",
        "import numpy as np\n",
        "import os\n",
        "\n",
        "print(tf.__version__)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MM6W__qraV55"
      },
      "source": [
        "## Загрузка датасета Fashion MNIST"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7MqDQO0KCaWS"
      },
      "outputs": [],
      "source": [
        "fashion_mnist = tf.keras.datasets.fashion_mnist\n",
        "\n",
        "(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()\n",
        "\n",
        "# Добавление дополнительного измерения в массив -> новая размерность == (28, 28, 1)\n",
        "# Мы делаем это, потому что первый слой в нашей модели - сверточный\n",
        "# и требует ввода 4D (размер пакета, высота, ширина, каналы).\n",
        "# размер пакета(batch_size) будет добавлен позже.\n",
        "train_images = train_images[..., None]\n",
        "test_images = test_images[..., None]\n",
        "\n",
        "# Нормализуем пиксели изображения в диапазон [0, 1].\n",
        "train_images = train_images / np.float32(255)\n",
        "test_images = test_images / np.float32(255)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4AXoHhrsbdF3"
      },
      "source": [
        "## Создание стратегии для распределенния переменных и графа"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5mVuLZhbem8d"
      },
      "source": [
        "Как работает стратегия `tf.distribute.MirroredStrategy`?\n",
        "\n",
        "* Все переменные и граф модели вычисляются на репликах.\n",
        "* Входные данные равномерно распределяются по репликам.\n",
        "* Каждая реплика вычисляет потери и градиенты для полученных входных данных.\n",
        "* Градиенты синхронизируются по всем репликам путем их суммирования.\n",
        "* После синхронизации такое же обновление производится для копий переменных на каждой реплике.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F2VeZUWUj5S4"
      },
      "outputs": [],
      "source": [
        "# Если список устройств не указан в\n",
        "# конструкторе `tf.distribute.MirroredStrategy`, он будет определенн автоматически.\n",
        "strategy = tf.distribute.MirroredStrategy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZngeM_2o0_JO"
      },
      "outputs": [],
      "source": [
        "print ('Колтчество доступных устройств: {}'.format(strategy.num_replicas_in_sync))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k53F5I_IiGyI"
      },
      "source": [
        "## Настройка входного конвейера"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Qb6nDgxiN_n"
      },
      "source": [
        "Экспортируйте граф и переменные в плтформо-независимый формат с помощью `SavedModel`. После сохранения модели вы можете загрузить ее как с контекстным менеджером, так и без него."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jwJtsCQhHK-E"
      },
      "outputs": [],
      "source": [
        "BUFFER_SIZE = len(train_images)\n",
        "\n",
        "BATCH_SIZE_PER_REPLICA = 64\n",
        "GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync\n",
        "\n",
        "EPOCHS = 10"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J7fj3GskHC8g"
      },
      "source": [
        "Создайте датасеты и распределите их:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WYrMNNDhAvVl"
      },
      "outputs": [],
      "source": [
        "train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) \n",
        "test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) \n",
        "\n",
        "train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)\n",
        "test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bAXAo_wWbWSb"
      },
      "source": [
        "## Создание модели\n",
        "\n",
        "Создайте модель, используя `tf.keras.Sequential`. Также вы можете использовать наследование от `Model`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9ODch-OFCaW4"
      },
      "outputs": [],
      "source": [
        "def create_model():\n",
        "  model = tf.keras.Sequential([\n",
        "      tf.keras.layers.Conv2D(32, 3, activation='relu'),\n",
        "      tf.keras.layers.MaxPooling2D(),\n",
        "      tf.keras.layers.Conv2D(64, 3, activation='relu'),\n",
        "      tf.keras.layers.MaxPooling2D(),\n",
        "      tf.keras.layers.Flatten(),\n",
        "      tf.keras.layers.Dense(64, activation='relu'),\n",
        "      tf.keras.layers.Dense(10)\n",
        "    ])\n",
        "\n",
        "  return model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9iagoTBfijUz"
      },
      "outputs": [],
      "source": [
        "# Создайте директорию для сохранения чекпойнтов.\n",
        "checkpoint_dir = './training_checkpoints'\n",
        "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e-wlFFZbP33n"
      },
      "source": [
        "## Определение функции потерь\n",
        "\n",
        "Обычно на одной компьютере с 1-м GPU/CPU потери делятся на размер пакета входных данных(batch_size).\n",
        "\n",
        "*Как же рассчитать потери при использовании tf.distribute.Strategy?*\n",
        "\n",
        "* Например, предположим, что у вас 4 графических процессора и размер пакета 64. Один пакет ввода распределяется\n",
        "  на 4 реплики (4 графических процессора), каждая реплика получает входящий пакет размером 16.\n",
        "\n",
        "* Модель на каждой реплике(графический процессор) выполняет прямой проход с соответствующими входными данными и вычисляет потери. Теперь, вместо деления потерь на количество примеров в соответствующем вводе (BATCH_SIZE_PER_REPLICA = 16), потери следует разделить на GLOBAL_BATCH_SIZE (64).\n",
        "\n",
        "*Зачем это делать?*\n",
        "\n",
        "* Это необходимо сделать, потому что после расчета градиентов на каждой реплике они синхронизируются между репликами путем **суммирования**.\n",
        "\n",
        "*Как это сделать в TensorFlow?*\n",
        "* Если вы пишете свой собственный цикл обучения, как в этом руководстве, вам следует просуммировать потери для каждого примера и разделить сумму на GLOBAL_BATCH_SIZE:\n",
        "  `scale_loss = tf.reduce_sum(потери) * (1. / GLOBAL_BATCH_SIZE)` или вы можете использовать `tf.nn.compute_average_loss`, который берет среднюю величину потерь на пакет, веса(если они используются) и GLOBAL_BATCH_SIZE в качестве аргументов и возвращает нормализованные потери.\n",
        "\n",
        "* Если вы используете в своей модели потери регуляризации, вам необходимо масштабировать\n",
        "  величина потерь по количеству реплик. Вы можете сделать это с помощью функции `tf.nn.scale_regularization_loss`.\n",
        "\n",
        "* Использование `tf.reduce_mean` не рекомендуется. При расчете этим методом, потери делятся на фактический размер пакета реплик, который может меняться шаг за шагом.\n",
        "\n",
        "* Уменьшение и масштабирование выполняется автоматически в файлах keras `model.compile` и` model.fit`\n",
        "\n",
        "* При использовании классов `tf.keras.losses` (как в примере ниже) необходимо явно указать уменьшение потерь, которое может быть либо `NONE`, либо `SUM`. \n",
        "* `AUTO` и `SUM_OVER_BATCH_SIZE` запрещены при использовании с `tf.distribute.Strategy`. \n",
        "* `AUTO` запрещен, потому что пользователь должен четко указать, какое уменьшение  потерь он хочет использовать, и убедиться, что оно правильно работает в распределенной стратегии. `SUM_OVER_BATCH_SIZE` запрещен, потому что в текущей реализации он будет делиться только на размер пакета на реплике, а деление на количество реплик останется на усмотрение пользователя, что может быть легко упущено. Поэтому мы просим пользователя указать уменьшение потерь самостоятельно.\n",
        "* Если метки являются многомерными, их необходимо усреднить по количеству элементов в каждой выборке - `per_example_loss`. Например, если форма `predictions` - `(batch_size, H, W, n_classes) `, а `labels` - `(batch_size, H, W)`, вам нужно будет обновить переменную `per_example_loss`: ` per_example_loss / = tf.cast (tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)`\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "R144Wci782ix"
      },
      "outputs": [],
      "source": [
        "with strategy.scope():\n",
        "  # Установите tf.keras.losses.Reduction в None, чтобы мы могли сделать уменьшение позже \n",
        "  # и разделить его на глобальный размер пакета.\n",
        "  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n",
        "      from_logits=True,\n",
        "      reduction=tf.keras.losses.Reduction.NONE)\n",
        "  def compute_loss(labels, predictions):\n",
        "    per_example_loss = loss_object(labels, predictions)\n",
        "    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w8y54-o9T2Ni"
      },
      "source": [
        "## Определение метрики для отслеживания потерь и точности\n",
        "\n",
        "Эти метрики отслеживают потери валидации, а также точность обучения и валидации. \n",
        "Вы можете использовать метод `.result()` для получения накопленной статистики в любое время."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zt3AHb46Tr3w"
      },
      "outputs": [],
      "source": [
        "with strategy.scope():\n",
        "  test_loss = tf.keras.metrics.Mean(name='test_loss')\n",
        "\n",
        "  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
        "      name='train_accuracy')\n",
        "  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
        "      name='test_accuracy')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iuKuNXPORfqJ"
      },
      "source": [
        "## Цикл обучения"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OrMmakq5EqeQ"
      },
      "outputs": [],
      "source": [
        "# модель, оптимайзер и чекпойнт должны быть созданы с использованием `strategy.scope`.\n",
        "with strategy.scope():\n",
        "  model = create_model()\n",
        "\n",
        "  optimizer = tf.keras.optimizers.Adam()\n",
        "\n",
        "  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3UX43wUu04EL"
      },
      "outputs": [],
      "source": [
        "def train_step(inputs):\n",
        "  images, labels = inputs\n",
        "\n",
        "  with tf.GradientTape() as tape:\n",
        "    predictions = model(images, training=True)\n",
        "    loss = compute_loss(labels, predictions)\n",
        "\n",
        "  gradients = tape.gradient(loss, model.trainable_variables)\n",
        "  optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
        "\n",
        "  train_accuracy.update_state(labels, predictions)\n",
        "  return loss \n",
        "\n",
        "def test_step(inputs):\n",
        "  images, labels = inputs\n",
        "\n",
        "  predictions = model(images, training=False)\n",
        "  t_loss = loss_object(labels, predictions)\n",
        "\n",
        "  test_loss.update_state(t_loss)\n",
        "  test_accuracy.update_state(labels, predictions)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gX975dMSNw0e"
      },
      "outputs": [],
      "source": [
        "# `strategy.run` реплицирует предоставленные расчеты и запускает их с распределенным вводом.\n",
        "@tf.function\n",
        "def distributed_train_step(dataset_inputs):\n",
        "  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))\n",
        "  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,\n",
        "                         axis=None)\n",
        "\n",
        "@tf.function\n",
        "def distributed_test_step(dataset_inputs):\n",
        "  return strategy.run(test_step, args=(dataset_inputs,))\n",
        "\n",
        "for epoch in range(EPOCHS):\n",
        "  # ЦИКЛ ОБУЧЕНИЯ\n",
        "  total_loss = 0.0\n",
        "  num_batches = 0\n",
        "  for x in train_dist_dataset:\n",
        "    total_loss += distributed_train_step(x)\n",
        "    num_batches += 1\n",
        "  train_loss = total_loss / num_batches\n",
        "\n",
        "  # ЦИКЛ ВАЛИДАЦИИ\n",
        "  for x in test_dist_dataset:\n",
        "    distributed_test_step(x)\n",
        "\n",
        "  if epoch % 2 == 0:\n",
        "    checkpoint.save(checkpoint_prefix)\n",
        "\n",
        "  template = (\"Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, \"\n",
        "              \"Test Accuracy: {}\")\n",
        "  print (template.format(epoch+1, train_loss,\n",
        "                         train_accuracy.result()*100, test_loss.result(),\n",
        "                         test_accuracy.result()*100))\n",
        "\n",
        "  test_loss.reset_states()\n",
        "  train_accuracy.reset_states()\n",
        "  test_accuracy.reset_states()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z1YvXqOpwy08"
      },
      "source": [
        "На что следует обратить внимание в приведенном выше примере:\n",
        "\n",
        "* Мы перебираем наборы `train_dist_dataset` и `test_dist_dataset`, используя конструкцию `for x in ...`\n",
        "* Масштабируемая потеря - это возвращаемое значение `distributed_train_step`. Это значение агрегируется по репликам с помощью вызова `tf.distribute.Strategy.reduce`, а затем по пакетам путем суммирования возвращаемого значения вызовов `tf.distribute.Strategy.reduce((tf.distribute.ReduceOp.SUM...`.\n",
        "* `tf.keras.Metrics`и должны быть обновлены внутри train_step и test_step, которые выполняется в `tf.distribute.Strategy.run`.\n",
        "* `tf.distribute.Strategy.run` возвращает результаты каждой локальной реплики в стратегии, и существует несколько способов получить результат. Вы можете выполнить `tf.distribute.Strategy.reduce`, чтобы получить агрегированное значение по всем репликам. Вы также можете выполнить `tf.distribute.Strategy.experimental_local_results`, чтобы получить список значений по каждой реплике и агрегировать его предпочитаемым способом.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-q5qp31IQD8t"
      },
      "source": [
        "## Восстановление последнего чекпойнта и тестирование"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WNW2P00bkMGJ"
      },
      "source": [
        "Модель, сохраненная с использованием `tf.distribute.Strategy` может быть восстановлена как с использованием `tf.distribute.Strategy`, так и без."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pg3B-Cw_cn3a"
      },
      "outputs": [],
      "source": [
        "eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
        "      name='eval_accuracy')\n",
        "\n",
        "new_model = create_model()\n",
        "new_optimizer = tf.keras.optimizers.Adam()\n",
        "\n",
        "test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7qYii7KUYiSM"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def eval_step(images, labels):\n",
        "  predictions = new_model(images, training=False)\n",
        "  eval_accuracy(labels, predictions)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LeZ6eeWRoUNq"
      },
      "outputs": [],
      "source": [
        "checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)\n",
        "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))\n",
        "\n",
        "for images, labels in test_dataset:\n",
        "  eval_step(images, labels)\n",
        "\n",
        "print ('Accuracy after restoring the saved model without strategy: {}'.format(\n",
        "    eval_accuracy.result()*100))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EbcI87EEzhzg"
      },
      "source": [
        "## Альтернативные способы перебора набора данных\n",
        "\n",
        "### Использование итераторов\n",
        "\n",
        "Если вы хотите перебрать определенное количество шагов, а не весь набор данных, вы можете создать итератор, используя вызов `iter` и явный вызов `next` на итераторе. Вы можете перебирать датасет как внутри, так и вне tf.function. Вот небольшой фрагмент, демонстрирующий итерацию набора данных вне tf.function с использованием итератора."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7c73wGC00CzN"
      },
      "outputs": [],
      "source": [
        "for _ in range(EPOCHS):\n",
        "  total_loss = 0.0\n",
        "  num_batches = 0\n",
        "  train_iter = iter(train_dist_dataset)\n",
        "\n",
        "  for _ in range(10):\n",
        "    total_loss += distributed_train_step(next(train_iter))\n",
        "    num_batches += 1\n",
        "  average_train_loss = total_loss / num_batches\n",
        "\n",
        "  template = (\"Epoch {}, Loss: {}, Accuracy: {}\")\n",
        "  print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))\n",
        "  train_accuracy.reset_states()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GxVp48Oy0m6y"
      },
      "source": [
        "### Итерация внутри tf.function\n",
        "\n",
        "Вы также можете перебирать весь `train_dist_dataset` внутри tf.function, используя конструкцию `for x in ... `или создавая итераторы, как мы делали выше. В приведенном ниже примере демонстрируется перенос одной эпохи обучения в tf.function и проход по `train_dist_dataset` внутри функции."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-REzmcXv00qm"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def distributed_train_epoch(dataset):\n",
        "  total_loss = 0.0\n",
        "  num_batches = 0\n",
        "  for x in dataset:\n",
        "    per_replica_losses = strategy.run(train_step, args=(x,))\n",
        "    total_loss += strategy.reduce(\n",
        "      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)\n",
        "    num_batches += 1\n",
        "  return total_loss / tf.cast(num_batches, dtype=tf.float32)\n",
        "\n",
        "for epoch in range(EPOCHS):\n",
        "  train_loss = distributed_train_epoch(train_dist_dataset)\n",
        "\n",
        "  template = (\"Epoch {}, Loss: {}, Accuracy: {}\")\n",
        "  print (template.format(epoch+1, train_loss, train_accuracy.result()*100))\n",
        "\n",
        "  train_accuracy.reset_states()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MuZGXiyC7ABR"
      },
      "source": [
        "### Отслеживание потерь при обучении по репликам\n",
        "\n",
        "Примечание. Как правило, вы должны использовать `tf.keras.Metrics` для отслеживания значений по выборке и избегать значений, которые были агрегированы в реплике.\n",
        "\n",
        "Мы **не** рекомендуем использовать `tf.keras.Metrics` для отслеживания потерь при обучении по разным репликам из-за выполнения масштабирования расчитанных потерь.\n",
        "\n",
        "Например, если вы выполняете учебное задание со следующими характеристиками:\n",
        "* Две реплики\n",
        "* На каждой реплике обрабатываются два образца\n",
        "* Полученные значения потерь: [2, 3] и [4, 5] на каждой реплике.\n",
        "* Глобальный размер пакета = 4\n",
        "\n",
        "При масштабировании потерь вы вычисляете значение потерь для пакета на каждой реплике, добавляя значения потерь и затем делите его на глобальный размер пакета. В данном случае: `(2 + 3) / 4 = 1,25` и` (4 + 5) / 4 = 2,25`.\n",
        "\n",
        "Если вы используете `tf.keras.Metrics` для отслеживания потерь в двух репликах, результат будет другим. В этом примере вы получаете `total`, равный 3,50, и `count`, равный 2, в результате чего `total`/`count` = 1,75 при вызове `result()` для метрики. Потери, рассчитанные с помощью `tf.keras.Metrics`, масштабируются с помощью дополнительного коэффициента, который равен количеству синхронизированных реплик."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xisYJaV9KZTN"
      },
      "source": [
        "### Руководства и примеры\n",
        "\n",
        "Вот несколько примеров использования стратегии распределния с пользовательскими циклами обучения:\n",
        "\n",
        "1. [Руководство по распределенному обучению](../../guide/distributed_training)\n",
        "2. [DenseNet](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/densenet/distributed_train.py) в примере используется `MirroredStrategy`.\n",
        "1. [BERT](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_classifier.py) пример обучался с использованием `MirroredStrategy` и `TPUStrategy`.\n",
        "Этот пример особенно полезен для понимания того, как выполнять загрузку из контрольной точки и создавать периодические контрольные точки во время распределенного обучения.\n",
        "2. [NCF](https://github.com/tensorflow/models/blob/master/official/recommendation/ncf_keras_main.py) пример обучался с использованием `MirroredStrategy`, которая может быть включена установкой флага `keras_use_ctl`.\n",
        "3. [NMT](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/nmt_with_attention/distributed_train.py) пример обучался с использованием `MirroredStrategy`.\n",
        "\n",
        "Дополнительные примеры перечислены в [Руководстве по распределенной стратегии](../../guide/distributed_training.ipynb#examples_and_tutorials)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6hEJNsokjOKs"
      },
      "source": [
        "## Что дальше\n",
        "\n",
        "* Попробуйте новый API `tf.distribute.Strategy` на своих моделях.\n",
        "* Посетите [Раздел производительности](../../guide/function.ipynb) в руководстве, чтобы узнать больше о других стратегиях и [инструментах](../../guide/profiler.md), которые можно использовать для улучшения производительности ваших моделей TensorFlow."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "custom_training.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
